"""
The code is released exclusively for review purposes with the following terms:
PROPRIETARY AND CONFIDENTIAL. UNAUTHORIZED USE, COPYING, OR DISTRIBUTION OF THE 
CODE, VIA ANY MEDIUM, IS STRICTLY PROHIBITED. BY ACCESSING THE CODE, THE 
REVIEWERS AGREE TO DELETE THEM FROM ALL MEDIA AFTER THE REVIEW PERIOD IS OVER.
"""

import  numpy as np
from tqdm import tqdm
from numpy.random import multivariate_normal
# from tensorflow import keras
# import tensorflow as tf
# tf.keras.backend.set_floatx('float64')
# tf.compat.v1.enable_eager_execution()
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.covariance import EmpiricalCovariance

from sklearn.metrics import mean_absolute_error
from sklearn.linear_model import LinearRegression
from scipy.optimize import lsq_linear
from sklearn.metrics import pairwise_distances
from IPython.display import display, Markdown 
# from julia import Base
# from julia import ConstrainedLasso as CL
from utils import lime_explanation

def _apply_preproc(values, weights, spec):
    """Fit-transform ``values`` with the preprocessor described by ``spec``.

    Returns ``(values, None)`` untouched when ``spec`` is None; otherwise
    returns ``(transformed_values, fitted_preprocessor)``.
    """
    if spec is None:
        return values, None

    # Weights are only forwarded when the spec explicitly asks for them.
    sample_weight = weights if spec["rescale_by_weights"] else None
    preproc_obj = spec["cls"](**spec["opts"])
    transformed = preproc_obj.fit_transform(values,
                                            sample_weight=sample_weight)
    return transformed, preproc_obj


def preproc_X_y(data_tuple_train, X_preproc, y_preproc):
    """Preprocess the features and targets of every environment.

    Note (kept from the original author): use of sample_weight is probably
    buggy right now.

    Parameters
    ----------
    data_tuple_train : sequence
        One entry per environment; entry[0] is X, entry[1] is y and
        entry[2] holds the sample weights.
    X_preproc, y_preproc : dict or None
        Preprocessor specification with keys ``"cls"`` (class to
        instantiate), ``"opts"`` (keyword arguments for the constructor) and
        ``"rescale_by_weights"`` (whether the environment weights are passed
        as ``sample_weight`` to ``fit_transform``).  ``None`` disables
        preprocessing for that component.

    Returns
    -------
    list of tuple
        ``(X, y, weights, X_preproc_obj, y_preproc_obj)`` per environment;
        the preprocessor objects are ``None`` when the matching spec is
        ``None``.
    """
    data_tuple_train_pp = []

    for env_data in data_tuple_train:
        X_raw, y_raw, weights = env_data[0], env_data[1], env_data[2]

        # The None check now happens before the spec is subscripted, so
        # passing None really skips preprocessing instead of raising.
        X, X_preproc_obj = _apply_preproc(X_raw, weights, X_preproc)
        y, y_preproc_obj = _apply_preproc(y_raw, weights, y_preproc)

        data_tuple_train_pp.append((
            X, y, weights,
            X_preproc_obj, y_preproc_obj))

    return data_tuple_train_pp

def display_data_stats(data_tuple_train):
    """Show summary statistics of every environment as notebook Markdown.

    For each environment (entry[0] = X, entry[1] = y, entry[2] = weights)
    this renders the feature mean, variance and empirical covariance, the
    label mean and variance, the weights and their mean, and opens a new
    matplotlib figure with a histogram of the weights.
    """
    def show(text):
        # Every statistic is rendered through the same display pipeline.
        display(Markdown(text))

    for env_idx, env_data in enumerate(data_tuple_train):
        show("**Environment %d**" % env_idx)
        features = env_data[0]
        targets = env_data[1].ravel()
        env_weights = env_data[2]

        # Feature statistics.
        show("X Mean = " + str(features.mean(axis=0)))
        show("X Variance = " + str(features.var(axis=0)))
        cov_estimator = EmpiricalCovariance().fit(features)
        show("X Covariance = " + str(cov_estimator.covariance_))

        # Label statistics.
        show("y Mean = " + str(targets.mean(axis=0)))
        show("y Variance = " + str(targets.var(axis=0)))

        # Weight statistics plus one histogram figure per environment.
        show("Weights = " + str(env_weights))
        show("Mean of Weights = " + str(env_weights.mean()))
        plt.figure()
        plt.hist(env_weights)
        
        
def simple_preproc_X_y(data_tuple_train):
    """Centre each environment by its weighted mean, then scale by weights.

    For every environment (entry[0] = X, entry[1] = y, entry[2] = weights)
    the weighted column means of X and the weighted mean of the flattened y
    are subtracted, and every sample is multiplied by its weight.

    Returns a list of ``(weighted_X, weighted_y, weights)`` tuples where
    ``weighted_y`` is a column vector and ``weights`` is the input object.
    """
    processed = []

    for env_data in data_tuple_train:
        features = env_data[0]
        targets = env_data[1].ravel()
        w = env_data[2]

        # Weighted means used for centring.
        feature_mean = np.average(features, axis=0, weights=w)
        target_mean = np.average(targets, weights=w)

        centred_features = features - feature_mean
        centred_targets = targets - target_mean

        processed.append((
            centred_features * (w[:, np.newaxis]),
            (centred_targets * (w)).reshape(-1, 1),
            w))

    return processed

def get_weights(X_ref, X, kernel_width=None):
    """Kernel weights of the rows of ``X`` relative to the point ``X_ref``.

    The weight of a row at distance ``dist`` from ``X_ref`` is
    ``sqrt(exp(-dist**2 / kernel_width**2))``, with distances computed by
    ``pairwise_distances`` using its default metric.

    Parameters
    ----------
    X_ref : array
        Reference point; reshaped to a single row.
    X : 2-D array
        Samples to weight, one per row.
    kernel_width : float, optional
        Kernel width; defaults to ``0.75 * sqrt(number of columns of X)``.

    Returns
    -------
    1-D array with one weight per row of ``X``.
    """
    n_features = X.shape[1]

    width = float(np.sqrt(n_features) * .75
                  if kernel_width is None else kernel_width)

    distances = pairwise_distances(X_ref.reshape(1, -1), X)
    similarity = np.exp(-(distances ** 2) / width ** 2)

    return np.sqrt(similarity).ravel()

def create_multivariate_normal_envs(fun,
                                    centers, variances, 
                                    ne_train, ne_test,
                                    rand_seed, kernel_widths):
    """Sample Gaussian train/test environments with kernel weights.

    One environment is created for every (variance, kernel width) pair; the
    environment with flat index ``k = i*len(kernel_widths) + j`` is centred
    at ``centers[k]`` with isotropic covariance ``variances[i] * I``.

    Parameters
    ----------
    fun : callable
        Labelling function, called on the transposed samples.
    centers : 2-D array
        Row ``k`` is the mean of environment ``k``.
    variances, kernel_widths : sequences
        Isotropic variances and kernel widths defining the environment grid.
    ne_train, ne_test : sequences
        Number of train / test samples of environment ``k`` at index ``k``.
    rand_seed : int
        Base seed; train uses ``rand_seed + k``, test ``rand_seed - k - 1``.

    Returns
    -------
    (data_tuple_train, data_tuple_test) : two lists of ``[X, y, w]`` lists,
    with ``y`` a float64 column vector and ``w`` the kernel weights of the
    samples around the environment mean.
    """
    _, dim = centers.shape
    n_widths = len(kernel_widths)

    def sample_env(seed, mean, cov, n_samples, width):
        # Dedicated generator per split keeps each draw reproducible.
        rng = np.random.RandomState(seed=seed)
        X = rng.multivariate_normal(mean, cov, size=n_samples)
        y = (fun(X.transpose())).astype(np.float64).reshape(-1, 1)
        w = get_weights(mean, X, kernel_width=width)
        return [X, y, w]

    train_envs = []
    test_envs = []

    for var_idx, variance in enumerate(variances):
        for width_idx, width in enumerate(kernel_widths):
            env_idx = var_idx * n_widths + width_idx
            mean = centers[env_idx]
            cov = variance * np.eye(dim)

            train_envs.append(sample_env(
                rand_seed + env_idx, mean, cov, ne_train[env_idx], width))

            # NOTE(review): this seed becomes negative when rand_seed is not
            # larger than env_idx, and RandomState rejects negative seeds --
            # confirm callers always pass a large enough rand_seed.
            test_envs.append(sample_env(
                rand_seed - env_idx - 1, mean, cov, ne_test[env_idx], width))

    return train_envs, test_envs

def erm_individual(data_tuple_train, data_tuple_test,
                   lr_options=None):
    """Fit one LinearRegression per environment and report pooled MAE.

    Parameters
    ----------
    data_tuple_train, data_tuple_test : sequences
        One entry per environment; entry[0] is X, entry[1] is y and
        entry[2] holds the sample weights.  Both must have an entry for
        every training environment.
    lr_options : dict, optional
        ``"init_opts"``: keyword arguments for ``LinearRegression`` (default
        ``{}``).  ``"rescale_by_weights"``: when true, each model is fitted
        with the squared weights as ``sample_weight`` and the reported MAE is
        weighted by the (unsquared) weights (default ``False``).  Omitting
        the argument or a key uses these defaults; previously the ``{}``
        default raised ``KeyError`` on first use.

    Returns
    -------
    list
        The fitted ``LinearRegression`` objects, one per environment.

    Side effects: displays the coefficients and intercept of each model and
    the total train / test MAE as Markdown.
    """
    if lr_options is None:
        lr_options = {}
    rescale_by_weights = lr_options.get("rescale_by_weights", False)
    init_opts = lr_options.get("init_opts", {})

    ys_train = []
    ys_test = []
    ypreds_train = []
    ypreds_test = []
    ws_train = []
    ws_test = []

    lr_envs = []

    for i, env_train in enumerate(data_tuple_train):

        # Squared weights in the least-squares fit correspond to scaling
        # each residual by the weight itself.
        if rescale_by_weights:
            sample_weight = env_train[2]**2.0
        else:
            sample_weight = None

        lr_env = LinearRegression(**init_opts).fit(
                    env_train[0],
                    env_train[1],
                    sample_weight=sample_weight
                    )
        lr_envs.append(lr_env)

        display(Markdown("Env " + str(i)+ " weights: " + str(lr_env.coef_)))
        display(Markdown("Env " + str(i)+ " intercept: " + str(lr_env.intercept_)))

        env_test = data_tuple_test[i]
        ypred_train = np.dot(env_train[0],
                             lr_env.coef_.transpose()) + lr_env.intercept_
        ypred_test = np.dot(env_test[0],
                            lr_env.coef_.transpose()) + lr_env.intercept_

        ys_train.append(env_train[1].ravel())
        ys_test.append(env_test[1].ravel())

        ws_train.append(env_train[2].ravel())
        ws_test.append(env_test[2].ravel())

        ypreds_train.append(ypred_train.ravel())
        ypreds_test.append(ypred_test.ravel())

    if rescale_by_weights:
        sample_weight_train = np.hstack(ws_train)
        sample_weight_test = np.hstack(ws_test)
    else:
        sample_weight_train = None
        sample_weight_test = None

    display(Markdown("Total train MAE: %.4f" % (mean_absolute_error(
                        np.hstack(ys_train),
                        np.hstack(ypreds_train),
                        sample_weight=sample_weight_train))))

    display(Markdown("Total test MAE: %.4f" % (mean_absolute_error(
                        np.hstack(ys_test),
                        np.hstack(ypreds_test),
                        sample_weight=sample_weight_test))))

    return lr_envs

def erm_union(data_tuple_train, data_tuple_test, lr_options=None):
    """Fit a single LinearRegression on the union of all environments.

    Parameters
    ----------
    data_tuple_train, data_tuple_test : sequences
        One entry per environment; entry[0] is X, entry[1] is y and
        entry[2] holds the sample weights.
    lr_options : dict, optional
        ``"init_opts"``: keyword arguments for ``LinearRegression`` (default
        ``{}``).  ``"rescale_by_weights"``: when true, the model is fitted
        with the squared weights as ``sample_weight`` and the reported MAE is
        weighted by the (unsquared) weights (default ``False``).

    Returns
    -------
    The fitted ``LinearRegression`` object.

    Side effects: displays the coefficients, the intercept and the total
    train / test MAE as Markdown.

    Notes
    -----
    The fit used to apply the squared weights unconditionally, even when
    ``rescale_by_weights`` was false; it now honours the flag, matching
    ``erm_individual`` and the MAE reporting below.  The ``{}`` default also
    raised ``KeyError`` on first use and is replaced by ``None``.
    """
    if lr_options is None:
        lr_options = {}
    rescale_by_weights = lr_options.get("rescale_by_weights", False)
    init_opts = lr_options.get("init_opts", {})

    num_envs = len(data_tuple_train)

    ys_train = []
    ys_test = []
    Xs_train = []
    Xs_test = []
    ws_train = []
    ws_test = []

    for i in range(num_envs):

        Xs_train.append(data_tuple_train[i][0])
        ys_train.append(data_tuple_train[i][1].ravel())
        ws_train.append(data_tuple_train[i][2].ravel())

        Xs_test.append(data_tuple_test[i][0])
        ys_test.append(data_tuple_test[i][1].ravel())
        ws_test.append(data_tuple_test[i][2].ravel())

    # Pool all environments into single design / target / weight arrays.
    X_train = np.vstack(Xs_train)
    y_train = np.hstack(ys_train)
    w_train = np.hstack(ws_train)

    X_test = np.vstack(Xs_test)
    y_test = np.hstack(ys_test)
    w_test = np.hstack(ws_test)

    if rescale_by_weights:
        # Squared weights in the fit correspond to scaling residuals by w.
        fit_sample_weight = w_train**2.0
        sample_weight_train = w_train
        sample_weight_test = w_test
    else:
        fit_sample_weight = None
        sample_weight_train = None
        sample_weight_test = None

    lr_all_env = LinearRegression(**init_opts).fit(
                X_train,
                y_train.reshape(-1,1),
                sample_weight=fit_sample_weight)
    ypred_train = lr_all_env.predict(X_train).ravel()
    ypred_test = lr_all_env.predict(X_test).ravel()

    display(Markdown("Env_all weights: " + str(lr_all_env.coef_)))
    display(Markdown("Env_all intercept: " + str(lr_all_env.intercept_)))

    display(Markdown("Total train MAE: %.4f" % (mean_absolute_error(
                        y_train, ypred_train,
                        sample_weight=sample_weight_train))))

    display(Markdown("Total test MAE: %.4f" % (mean_absolute_error(
                        y_test, ypred_test,
                        sample_weight=sample_weight_test))))

    return lr_all_env

def prep_data_lrg_lsq(data_tuple_train, 
                        rescale_by_weights=False,
                        fit_intercept=False):
    """Build the per-environment design matrices and target vectors.

    Parameters
    ----------
    data_tuple_train : sequence
        One entry per environment; entry[0] is X, entry[1] is y and
        entry[2] holds the sample weights.
    rescale_by_weights : bool
        When true, every row of X and every target is multiplied by the
        sample weight of that row.
    fit_intercept : bool
        When true, a column of ones is appended to X (before any rescaling).

    Returns
    -------
    ne_train : array with the number of samples per environment
    num_envs : number of environments
    p : number of columns of the design matrices (including the intercept
        column when requested)
    Xall : list of design matrices, one per environment
    yall : list of flattened target vectors, one per environment
    """
    num_envs = len(data_tuple_train)
    ne_train = np.array([env_data[0].shape[0]
                         for env_data in data_tuple_train])

    n_features = data_tuple_train[0][0].shape[1]
    p = n_features + 1 if fit_intercept else n_features

    Xall = []
    yall = []
    for env_data in data_tuple_train:
        X_env = env_data[0]
        y_env = env_data[1].ravel()

        if fit_intercept:
            # Trailing column of ones carries the intercept coefficient.
            X_env = np.hstack([X_env, np.ones((X_env.shape[0], 1))])

        if rescale_by_weights:
            w_env = env_data[2]
            X_env = w_env.reshape(-1, 1) * X_env
            y_env = w_env * y_env

        Xall.append(X_env)
        yall.append(y_env)

    return ne_train, num_envs, p, Xall, yall

def lrg_lsq_sparse_old(data_tuple_train, w_all, lrg_config):
    """Alternating sparse least squares over environments (Julia solver).

    Each environment in turn refits its own coefficient vector on the
    residual left by the other environments, using the solution path
    returned by ``ConstrainedLasso.lsq_classopath`` under the box constraint
    ``|w| <= lrg_config["bound"]``.

    Parameters
    ----------
    data_tuple_train : sequence
        Per-environment data, forwarded to ``prep_data_lrg_lsq``.
    w_all : array
        Current per-environment coefficients, indexed ``w_all[env]``; its
        flattened length must be ``num_envs * p``.  Updated IN PLACE.
    lrg_config : dict
        Keys read here: ``"rescale_by_weights"``, ``"fit_intercept"``,
        ``"num_iters"``, ``"bound"``, ``"randomize_iterations"`` and
        ``"num_nonzeros"``.

    Returns
    -------
    w_lrg_lsq : sum of the per-environment coefficients after the last pass
    w_all_iters : flattened ``w_all`` before the first pass and after each
        pass, one column per snapshot
    mae_iters : value of ``perf_lrg_lsq`` for the summed coefficients at the
        same snapshots

    Notes
    -----
    ``w_lrg_lsq`` is only assigned inside the iteration loop, so
    ``num_iters == 0`` raises ``UnboundLocalError`` at the return statement.
    Requires the ``julia`` Python package with the ``ConstrainedLasso``
    Julia package; both are imported lazily.
    """
    from julia import Base
    from julia import ConstrainedLasso as CL

    # data_tuple_train = data_train_env

    ne_train, num_envs, p, Xall, yall = prep_data_lrg_lsq(data_tuple_train,
                        rescale_by_weights=lrg_config["rescale_by_weights"],
                        fit_intercept=lrg_config["fit_intercept"])
    # n is not used in the remainder of this function.
    n = len(yall[0])

    w_all_iters = np.zeros((num_envs*p, 
                            lrg_config["num_iters"]+1))
    mae_iters = np.zeros(lrg_config["num_iters"]+1)

    # Init to 0 just to make sure the soln is not biased towards
    # any env
    # w_all = np.zeros((num_envs, p))
    mae_iters[0] = perf_lrg_lsq(yall, Xall, w_all.sum(axis=0))

    # Record initial values
    w_all_iters[:, 0] = w_all.ravel()
    # Boolean mask selecting "all environments except the current one".
    mask_arr = np.ones(num_envs, dtype=bool)

    # Create necessary matrices
    # Aineq @ w <= bineq encodes w <= bound and -w <= bound.
    bineq = np.hstack((lrg_config["bound"]*np.ones(p),
                      lrg_config["bound"]*np.ones(p)))
    Aineq = np.vstack((np.eye(p), -1.0*np.eye(p)))

    # Alternate between the environments
    for iteri in range(lrg_config["num_iters"]):
#         print("iter = ", iteri)
#         print("=======")
        
        env_list = np.arange(num_envs)
        if lrg_config["randomize_iterations"]:
            # Uses the global NumPy RNG, so the visiting order depends on
            # the global seed state.
            np.random.shuffle(env_list)
        
        for env in env_list:
            # Residual of this environment's targets after removing the
            # predictions made by all the other environments' coefficients.
            mask_arr[env] = False
            ypartial = yall[env] - np.sum(
                        np.dot(Xall[env], w_all[mask_arr].transpose()), 
                        axis = 1)
            mask_arr[env] = True

            result = CL.lsq_classopath(Xall[env], ypartial, 
            #                         Aeq = Aeq, beq = beq,
                                    Aineq = Aineq, bineq = bineq
                                    )
            
            (betapath, rhopath, objvalpath, lambdapatheq,
                 mupathineq, dfpath, violationspath) = result
                 
#             print("env = ", env)
#             print("**********")
#             print(rhopath)
#             print((betapath != 0.0).sum(axis=0))
#             print(np.diff(betapath != 0.0, axis=1).sum(axis=0))
            
            # Last path column whose number of nonzero coefficients does not
            # exceed num_nonzeros; raises IndexError if no column qualifies.
            # Presumably betapath holds one solution per column -- confirm
            # against the ConstrainedLasso documentation.
            betaind = np.where((betapath != 0.0).sum(axis=0) 
                               <= lrg_config["num_nonzeros"])[0][-1]

            # In-place update: the caller's w_all array is modified.
            w_all[env] = betapath[:, betaind]


        # Snapshot after a full pass over the environments.
        w_all_iters[:, iteri+1] = w_all.ravel()
        w_lrg_lsq = w_all.sum(axis=0)

        mae_iters[iteri+1] = perf_lrg_lsq(yall, Xall, w_lrg_lsq)
        
    return w_lrg_lsq, w_all_iters, mae_iters

def lrg_lsq(data_tuple_train, lrg_config):
    """Alternating bounded least squares over environments.

    Every environment owns one coefficient vector; the overall predictor is
    their sum.  In each pass every environment in turn refits its own vector
    with ``scipy.optimize.lsq_linear`` on the residual left by the other
    environments, subject to ``|w| <= lrg_config["bound"]``.

    Parameters
    ----------
    data_tuple_train : sequence
        Per-environment data, forwarded to ``prep_data_lrg_lsq``.
        Environments may have different numbers of samples.
    lrg_config : dict
        Keys read here: ``"rescale_by_weights"``, ``"fit_intercept"``,
        ``"num_iters"``, ``"ridge"``, ``"ridge_penalty_multiplier"`` (only
        when ``"ridge"`` is true), ``"bound"`` and ``"max_iter"``.

    Returns
    -------
    w_lrg_lsq : sum of the per-environment coefficients after the last pass
        (all zeros when ``num_iters`` is 0)
    w_all_iters : flattened per-environment coefficients before the first
        pass and after each pass, one column per snapshot
    mae_iters : value of ``perf_lrg_lsq`` for the summed coefficients at the
        same snapshots
    """
    _, num_envs, p, Xall, yall = prep_data_lrg_lsq(data_tuple_train,
                        rescale_by_weights=lrg_config["rescale_by_weights"],
                        fit_intercept=lrg_config["fit_intercept"])
    num_iters = lrg_config["num_iters"]

    w_all_iters = np.zeros((num_envs*p, num_iters+1))
    mae_iters = np.zeros(num_iters+1)

    # Init to 0 just to make sure the soln is not biased towards
    # any env
    w_all = np.zeros((num_envs, p))
    # Defined up front so the return value exists even when num_iters == 0.
    w_lrg_lsq = w_all.sum(axis=0)
    mae_iters[0] = perf_lrg_lsq(yall, Xall, w_lrg_lsq)

    # Record initial values
    w_all_iters[:, 0] = w_all.ravel()
    mask_arr = np.ones(num_envs, dtype=bool)

    # The design matrices do not change between passes, so the (optionally
    # ridge-augmented) versions are built once, one per environment.  Doing
    # this per environment also removes the former requirement that all
    # environments share the sample count of the first one.
    use_ridge = lrg_config["ridge"]
    if use_ridge:
        # Appending multiplier*I rows with zero targets adds the penalty
        # multiplier**2 * ||w||**2 to the squared residual.
        ridge_rows = lrg_config["ridge_penalty_multiplier"]*np.eye(p)
        ridge_targets = np.zeros(p)
        X_design = [np.vstack([np.asarray(X_env, dtype=float), ridge_rows])
                    for X_env in Xall]
    else:
        X_design = [np.asarray(X_env, dtype=float) for X_env in Xall]

    # Alternate between the environments
    for iteri in range(num_iters):

        for env in range(num_envs):
            # Residual after removing the other environments' predictions.
            mask_arr[env] = False
            ypartial = yall[env] - np.sum(
                        np.dot(Xall[env], w_all[mask_arr].transpose()),
                        axis = 1)
            mask_arr[env] = True

            if use_ridge:
                y_design = np.concatenate([ypartial, ridge_targets])
            else:
                y_design = np.asarray(ypartial, dtype=float)

            result = lsq_linear(X_design[env], y_design,
                       bounds=(-lrg_config["bound"], lrg_config["bound"]),
                       max_iter=lrg_config["max_iter"],
                       lsq_solver="exact")

            w_all[env] = result.x

        # Snapshot after a full pass over the environments.
        w_all_iters[:, iteri+1] = w_all.ravel()
        w_lrg_lsq = w_all.sum(axis=0)

        mae_iters[iteri+1] = perf_lrg_lsq(yall, Xall, w_lrg_lsq)

    return w_lrg_lsq, w_all_iters, mae_iters
    
    
def perf_lrg_lsq(yall, Xall, w_lrg_lsq):
    """Mean absolute error of one coefficient vector over all environments.

    ``Xall`` and ``yall`` are the per-environment design matrices and target
    vectors; the predictions ``X @ w_lrg_lsq`` of all environments are
    concatenated and compared against the concatenated targets.
    """
    targets = np.hstack(yall)
    predictions = np.hstack([np.dot(X_env, w_lrg_lsq) for X_env in Xall])
    return mean_absolute_error(targets, predictions)

def mae_lrg_lsq(data_tuple_train, w_lrg_lsq, lrg_config):
    """MAE of ``w_lrg_lsq`` on data prepared according to ``lrg_config``.

    The data is passed through ``prep_data_lrg_lsq`` with the
    ``"rescale_by_weights"`` and ``"fit_intercept"`` settings of
    ``lrg_config`` and then scored with ``perf_lrg_lsq``.
    """
    prepared = prep_data_lrg_lsq(
        data_tuple_train,
        rescale_by_weights=lrg_config["rescale_by_weights"],
        fit_intercept=lrg_config["fit_intercept"])

    # Only the design matrices and targets are needed for scoring.
    _, _, _, design_matrices, targets = prepared
    return perf_lrg_lsq(targets, design_matrices, w_lrg_lsq)

def lrg_bound(lr_indiv):
    """Return three times the largest absolute coefficient or intercept.

    ``lr_indiv`` is an iterable of fitted models exposing ``coef_`` (array)
    and ``intercept_`` (scalar or array).
    """
    magnitudes = []
    for model in lr_indiv:
        magnitudes.extend(model.coef_.ravel())
        # intercept_ may be a scalar, so wrap it before flattening.
        magnitudes.extend(np.array([model.intercept_]).ravel())

    largest = np.max(np.abs(magnitudes))
    return 3.0 * largest

def infer_lrg_from_erm(lr_indiv):
    """Infer one coefficient vector from per-environment linear models.

    The first row of each model's ``coef_`` and its ``intercept_`` are
    stacked into one row per environment (intercept last).  Each column is
    then reduced independently with ``infer_lrg_coef``.

    Returns a 1-D array of length ``num_feats + 1`` whose last entry is the
    inferred intercept.
    """
    n_feats = len(lr_indiv[0].coef_[0])

    # One row per environment: [coef_0, ..., coef_{n_feats-1}, intercept].
    stacked = np.zeros((len(lr_indiv), n_feats + 1))
    for row, model in enumerate(lr_indiv):
        stacked[row, :-1] = model.coef_[0]
        stacked[row, -1] = np.array([model.intercept_]).ravel()[0]

    inferred = np.zeros(n_feats + 1)
    for col in range(n_feats + 1):
        per_env_values = stacked[:, col].ravel()
        inferred[col] = infer_lrg_coef(per_env_values)

    return inferred

def infer_lrg_coef(fc):
    """Reduce the per-environment values of one coefficient to one value.

    ``fc`` is a 1-D array.  For an odd number of values the median is
    returned.  For an even number the two middle values of the sorted array
    are compared: if their signs agree, the one with the smaller magnitude
    is returned (with that sign); otherwise the result is 0.0.
    """
    ranked = fc[np.argsort(fc)]
    count = len(ranked)
    mid = count // 2

    # Odd count: the middle element is the median.
    if count % 2 != 0:
        return ranked[mid]

    lower = ranked[mid - 1]
    upper = ranked[mid]

    # Disagreeing signs (including zero vs. nonzero) collapse to zero.
    if np.sign(lower) != np.sign(upper):
        return 0.0

    return np.sign(lower) * np.minimum(np.abs(lower), np.abs(upper))

def perf_envwise_overall(Xl, yl, weights, coefs):
    """Weighted MAE of per-environment predictions, pooled over environments.

    Xl - input matrices, iterated one per environment
    yl - outcomes, iterated one per environment
    weights - array of weights; its flattened form is the sample weight of
              the pooled MAE, so it must have one entry per pooled sample
    coefs - coefficient vectors, one per environment

    The four arguments are iterated in lockstep, so the shortest one
    determines how many environments are scored.
    """
    pooled_weights = weights.ravel()

    actual = []
    predicted = []
    for X_env, y_env, _w_env, coef_env in zip(Xl, yl, weights, coefs):
        env_prediction = np.dot(X_env, coef_env)

        actual.append(y_env.ravel())
        predicted.append(env_prediction.ravel())

    return mean_absolute_error(np.hstack(actual),
                               np.hstack(predicted),
                               sample_weight=pooled_weights)

def lrg_lsq_sparse(data_tuple_train, w_envs_current, lrg_config):
    """ We control the sparsity of the overall solution

    Alternates over the environments; for each one the overall coefficient
    vector is re-estimated from that environment's data with
    ``ConstrainedLasso.lsq_classopath`` and the environment's own share is
    recomputed as the difference to the other environments' contribution.

    Parameters
    ----------
    data_tuple_train : sequence
        Per-environment data, forwarded to ``prep_data_lrg_lsq``.
    w_envs_current : array
        Per-environment coefficients, indexed ``w_envs_current[env]``; their
        sum over axis 0 is the overall coefficient vector.  Updated IN PLACE.
    lrg_config : dict
        Keys read here: ``"rescale_by_weights"``, ``"fit_intercept"``,
        ``"num_iters"``, ``"randomize_iterations"``, ``"bound"`` and
        ``"num_nonzeros"``.

    Returns
    -------
    w_current : overall coefficient vector after the last update
    w_all_iters : overall coefficients before the first pass and after each
        pass, one column per snapshot
    mae_iters : value of ``perf_lrg_lsq`` at the same snapshots
    w_envs_current : the (mutated) per-environment coefficients

    Requires the ``julia`` Python package with the ``ConstrainedLasso``
    Julia package; both are imported lazily.
    """
    from julia import Base
    from julia import ConstrainedLasso as CL

    ne_train, num_envs, p, Xall, yall = prep_data_lrg_lsq(data_tuple_train,
                        rescale_by_weights=lrg_config["rescale_by_weights"],
                        fit_intercept=lrg_config["fit_intercept"])
    # n is not used in the remainder of this function.
    n = len(yall[0])

    w_all_iters = np.zeros((p, 
                            lrg_config["num_iters"]+1))
    mae_iters = np.zeros(lrg_config["num_iters"]+1)
    # Overall coefficients = sum of the per-environment coefficients.
    w_current = w_envs_current.sum(axis=0)

    mae_iters[0] = perf_lrg_lsq(yall, Xall, w_current)

    # Record initial values
    w_all_iters[:, 0] = w_current
    # Aineq @ w <= bineq encodes an upper bound on w and on -w.
    Aineq = np.vstack((np.eye(p), -1.0*np.eye(p)))

    # Alternate between the environments
    for iteri in range(lrg_config["num_iters"]):
#         print("iter = ", iteri)
#         print("=======")
        
        env_list = np.arange(num_envs)
        if lrg_config["randomize_iterations"]:
            # Uses the global NumPy RNG, so the visiting order depends on
            # the global seed state.
            np.random.shuffle(env_list)
        
        for env in env_list:
            w_i_current = w_envs_current[env]
            # Contribution of all the other environments.
            w_not_i_current = w_current-w_i_current
            
            # NOTE(review): both halves of bineq are the same
            # |bound + w_not_i| + 1e-6, i.e. the box on the overall
            # coefficients is symmetric around zero rather than around
            # w_not_i_current -- confirm this is the intended constraint.
            bineq = np.hstack((
                        np.abs(lrg_config["bound"]*np.ones(p)+w_not_i_current)+1e-6,
                        np.abs(lrg_config["bound"]*np.ones(p)+w_not_i_current)+1e-6))

            # The solver is fitted on the full targets of this environment,
            # so its solution is the overall coefficient vector.
            result = CL.lsq_classopath(Xall[env], yall[env], 
                                    Aineq = Aineq, bineq = bineq
                                    )
            
            (betapath, rhopath, objvalpath, lambdapatheq,
                 mupathineq, dfpath, violationspath) = result

            # result = lsq_linear(Xall[env], yall[env], 
            #            bounds=(-lrg_config["bound"], lrg_config["bound"]), 
            #            max_iter=lrg_config["max_iter"],
            #            lsq_solver="exact")         

#             print("env = ", env)
#             print("**********")
#             print(rhopath)
#             print((betapath != 0.0).sum(axis=0))
#             print(np.diff(1.0*(betapath != 0.0), axis=1).sum(axis=0))
            
            # pick the first index where the coefficient enters the path
            # i.e. the first path column with exactly num_nonzeros nonzero
            # entries.  Presumably betapath holds one solution per column --
            # confirm against the ConstrainedLasso documentation.
            betaind1 = np.where((betapath != 0.0).sum(axis=0) 
                               == lrg_config["num_nonzeros"])[0]

            if betaind1.size > 0:
                betaind = betaind1[0]
            else:
                # No exact match: fall back to the column whose nonzero
                # count is closest to the requested one.
                betaind = np.argmin(np.abs((betapath != 0.0).sum(axis=0)
                                  - lrg_config["num_nonzeros"]))                

            # Update the various coefficients
            # w_current = result.x
            w_current = betapath[:, betaind]
            # This environment's share is whatever the others do not cover.
            w_i_current = w_current - w_not_i_current
            # In-place update: the caller's array is modified.
            w_envs_current[env] = w_i_current
            
        # Snapshot after a full pass over the environments.
        w_all_iters[:, iteri+1] = w_current

        mae_iters[iteri+1] = perf_lrg_lsq(yall, Xall, w_current)
        
    return w_current, w_all_iters, mae_iters, w_envs_current

def linf_project(x, c):
    """Clamp ``x`` elementwise into ``[-c, c]``, in place, and return it.

    ``x`` and ``c`` are arrays of the same shape.  Entries of ``x`` at or
    above ``c`` are set to ``c``; afterwards entries at or below ``-c`` are
    set to ``-c``.  The input array ``x`` itself is modified and returned.
    """
    above = x >= c
    x[above] = c[above]

    # Evaluated after the upper clamp, matching the sequential semantics.
    below = x <= -c
    x[below] = -c[below]

    return x

def lrg_lsq_sparse_simple(data_tuple_train, w_envs_current, lrg_config):
    """ We control the sparsity of the overall solution

    Alternates over the environments; for each one the overall coefficient
    vector is re-estimated from that environment's data with
    ``lime_explanation`` (unit sample weights, ``debias=False``), clamped
    with ``linf_project``, and the environment's own share is recomputed as
    the difference to the other environments' contribution.

    Parameters
    ----------
    data_tuple_train : sequence
        Per-environment data, forwarded to ``prep_data_lrg_lsq``.
    w_envs_current : array
        Per-environment coefficients, indexed ``w_envs_current[env]``; their
        sum over axis 0 is the overall coefficient vector.  Updated IN PLACE.
    lrg_config : dict
        Keys read here: ``"rescale_by_weights"``, ``"fit_intercept"``,
        ``"num_iters"``, ``"randomize_iterations"``, ``"bound"`` and
        ``"num_nonzeros"``.

    Returns
    -------
    w_current : overall coefficient vector after the last update
    w_all_iters : overall coefficients before the first pass and after each
        pass, one column per snapshot
    mae_iters : value of ``perf_lrg_lsq`` at the same snapshots
    w_envs_current : the (mutated) per-environment coefficients
    """
    _, num_envs, p, Xall, yall = prep_data_lrg_lsq(
        data_tuple_train,
        rescale_by_weights=lrg_config["rescale_by_weights"],
        fit_intercept=lrg_config["fit_intercept"])

    num_iters = lrg_config["num_iters"]
    w_all_iters = np.zeros((p, num_iters + 1))
    mae_iters = np.zeros(num_iters + 1)

    # Overall coefficients = sum of the per-environment coefficients.
    w_current = w_envs_current.sum(axis=0)
    # Unit weights sized by the first environment, shared by every fit.
    sample_weights = np.ones(len(yall[0]))

    # Snapshot of the starting point.
    mae_iters[0] = perf_lrg_lsq(yall, Xall, w_current)
    w_all_iters[:, 0] = w_current

    for pass_idx in range(num_iters):
        visit_order = np.arange(num_envs)
        if lrg_config["randomize_iterations"]:
            # Uses the global NumPy RNG.
            np.random.shuffle(visit_order)

        for env in visit_order:
            w_env = w_envs_current[env]
            # Contribution of all the other environments.
            w_others = w_current - w_env

            box = np.abs(lrg_config["bound"] * np.ones(p) + w_others) + 1e-6

            # Sparse fit of the overall coefficients on this environment,
            # then clamp them into the box.
            sparse_fit = lime_explanation(
                Xall[env], yall[env], sample_weights,
                num_nonzeros=lrg_config["num_nonzeros"],
                debias=False)
            w_current = linf_project(sparse_fit, box)

            # This environment's share is whatever the others do not cover.
            w_envs_current[env] = w_current - w_others

        w_all_iters[:, pass_idx + 1] = w_current
        mae_iters[pass_idx + 1] = perf_lrg_lsq(yall, Xall, w_current)

    return w_current, w_all_iters, mae_iters, w_envs_current